import itertools
import random
import warnings

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from scipy.io import loadmat
from sklearn.metrics import precision_score, recall_score, f1_score, roc_auc_score, accuracy_score, confusion_matrix
from sklearn.model_selection import KFold, StratifiedKFold
from torch import nn
from torch.utils.data import SubsetRandomSampler, DataLoader, Dataset

from NeruoH_TGL.model import model_all


def MaxMinNormalization(x, Max, Min):
    """Min-max scale ``x`` so that ``Min`` maps to 0 and ``Max`` maps to 1.

    Works element-wise on NumPy arrays / tensors as well as on plain scalars.
    There is no guard against ``Max == Min`` (that divides by zero).
    """
    offset = x - Min
    span = Max - Min
    return offset / span


# Silence every Python warning for the rest of the run.
# NOTE(review): this also hides library warnings (torch, sklearn) that can point
# at real problems; consider a narrower filter.
warnings.filterwarnings("ignore")
# Seed torch (CPU and every CUDA device), NumPy's global RNG and Python's
# `random` with the same value so the data shuffle and torch-side randomness
# are repeatable. `seed` is also reused as the StratifiedKFold random_state.
# NOTE(review): cuDNN determinism flags are not set, so GPU runs may still not
# be bit-for-bit reproducible.
seed = 7
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)  # Python's built-in RNG

# Device for the model and every batch: first CUDA GPU if available, else CPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# NOTE(review): `num_windows` is not referenced elsewhere in the visible file;
# the window count actually used comes from the `win` hyper-parameter loop.
num_windows = 8
# Placeholders for fold-averaged test metrics, initialised to 0.
avg_acc = 0
avg_spe = 0
avg_recall = 0
avg_f1 = 0
avg_auc = 0
# Result accumulators. `label_ten` is rebound to a fresh list later in the
# script; `pre_ten`, `gailv_ten` ("gailv" presumably pinyin for "probability")
# and `kk` appear unused in the visible file.
pre_ten = []
label_ten = []
gailv_ten = []
kk = 10






def stest(model, datasets_test):
    """Evaluate ``model`` on ``datasets_test`` and compute binary-classification metrics.

    Each batch from ``datasets_test`` is a 3-tuple: the first two items are moved
    to ``DEVICE``, cast to float and passed to ``model`` in that order; the third
    is the integer class label. ``model`` must return ``(outs, loss_CC, loss_REC)``
    and the per-batch loss is
    ``nll_loss(outs, label) + lam * loss_CC + alpha * loss_REC``.

    NOTE: ``DEVICE``, ``lam`` and ``alpha`` are read from module-level globals.

    Args:
        model: Network to evaluate; it is switched to eval mode (and left there).
        datasets_test: Iterable of batches, e.g. a ``DataLoader``.

    Returns:
        Tuple ``(eval_loss, eval_acc, eval_acc_epoch, precision, recall, f1,
        my_auc, pre_all, labels_all, pro_all)``:
        ``eval_loss`` / ``eval_acc`` are *sums* of the per-batch loss / accuracy
        (callers divide by the number of batches); ``eval_acc_epoch``,
        ``precision``, ``recall``, ``f1`` and ``my_auc`` are computed over all
        samples; ``pre_all`` / ``labels_all`` are predicted / true labels;
        ``pro_all`` holds ``outs[:, 1]``, the class-1 score used for AUC
        (presumably a log-probability, since the loss is NLL — AUC only needs
        the ranking).

    Raises:
        ValueError: from ``roc_auc_score`` when ``labels_all`` contains a single
            class (or is empty).
    """
    eval_loss = 0
    eval_acc = 0
    pre_all = []
    labels_all = []
    pro_all = []

    model.eval()
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for net, data_feas, label in datasets_test:
            net = net.to(DEVICE).float()
            data_feas = data_feas.to(DEVICE).float()
            label = label.to(DEVICE).long()
            outs, loss_CC, loss_REC = model(net, data_feas)

            # Use this batch's outputs (`outs`), not the module-level `out`
            # left over from the last training batch.
            losss = F.nll_loss(outs, label) + lam * loss_CC + alpha * loss_REC
            eval_loss += float(losss)

            _, pred = outs.max(1)
            num_correct = (pred == label).sum()
            eval_acc += int(num_correct) / net.shape[0]

            pre_all.extend(pred.cpu().numpy())
            labels_all.extend(label.cpu().numpy())
            pro_all.extend(outs[:, 1].cpu().numpy())

    eval_acc_epoch = accuracy_score(labels_all, pre_all)
    precision = precision_score(labels_all, pre_all)
    recall = recall_score(labels_all, pre_all)
    f1 = f1_score(labels_all, pre_all)
    my_auc = roc_auc_score(labels_all, pro_all)

    return eval_loss, eval_acc, eval_acc_epoch, precision, recall, f1, my_auc, pre_all, labels_all, pro_all



def create_DFCN(dataset, num_window, yuzhi):
    """Build dynamic (sliding-window) functional-connectivity networks.

    For every subject ``dataset[i]``, ``num_window`` overlapping windows are cut
    along the second axis at offsets 0, 10, 20, ..., each
    ``140 - num_window * 10`` columns long. For every window the absolute
    Pearson correlation matrix between its rows (``np.corrcoef``) is computed.

    Args:
        dataset: Indexable of per-subject 2-D arrays; rows are correlated with
            each other and columns are sliced into windows.
            NOTE(review): presumably (ROIs, time points); ``num_window=6`` needs
            at least 130 columns — confirm against the .mat file.
        num_window: Number of windows per subject.
        yuzhi: Threshold (pinyin "yuzhi"); absolute correlations below it are
            set to 0.

    Returns:
        ``(nets_all, fmris_all)``: thresholded correlation matrices of shape
        ``(N, num_window, R, R)`` and the raw windows of shape
        ``(N, num_window, R, win_length)``.
    """
    win_length = 140 - num_window * 10
    nets_all = []
    fmris_all = []
    for subject in dataset:
        windows = [subject[:, j * 10: j * 10 + win_length] for j in range(num_window)]
        fmris_all.append(windows)
        nets_all.append([np.abs(np.corrcoef(w)) for w in windows])
    nets_all = np.array(nets_all)
    fmris_all = np.array(fmris_all)
    nets_all[nets_all < yuzhi] = 0
    return nets_all, fmris_all


class Dianxian(Dataset):
    """Map-style dataset yielding ``(fmri_windows, networks, label)`` per subject."""

    def __init__(self, feas, nets, label):
        super().__init__()
        self.feas = feas
        self.nets = nets
        self.label = label

    def __getitem__(self, item):
        return self.feas[item], self.nets[item], self.label[item]

    def __len__(self):
        return self.feas.shape[0]


# Hyper-parameter sweep with stratified 10-fold cross-validation. For every
# (win, lr, lam, alpha): load data, build window networks, train one model per
# fold with early stopping on a validation split, test the best checkpoint on
# the held-out fold, and append a summary to the log (closed on exit).
with open('ADNI_Win6_NC_MCI.txt', mode='a', encoding='utf-8') as log:
    # Same iteration order as nested loops over win / lr / lam / alpha.
    # `lam` and `alpha` stay module-level names because `stest` reads them.
    for win, lr, lam, alpha in itertools.product([6], [7e-4], [1e-3], [1e-4]):
        # .mat keys: '__header__', '__version__', '__globals__', 'rs_fmri', 'label'
        a = loadmat('...')
        fdata = a['rs_fmri'][0:408, :, :]
        print(fdata.shape)
        labels = a['label'][0]
        print(labels.shape)
        labels = labels[0:408]
        # Relabel class 2 as class 1: the sklearn metrics in `stest` use the
        # binary (0/1) defaults.
        labels[labels == 2] = 1

        # Shuffle subjects with the globally seeded NumPy RNG.
        index = list(range(fdata.shape[0]))
        np.random.shuffle(index)
        fdata = fdata[index]
        labels = labels[index]

        nets_all, fmris_all = create_DFCN(fdata, win, 0.6)
        print(fmris_all.shape)
        print(fmris_all.shape[3])

        test_acc = []
        test_pre = []
        test_recall = []
        test_f1 = []
        test_auc = []
        label_ten = []
        pro_ten = []
        train_ratio = 0.8
        dataset = Dianxian(fmris_all, nets_all, labels)
        skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=seed)
        for fold, (train_idx, test_idx) in enumerate(skf.split(X=dataset.feas, y=dataset.label)):
            ckpt_path = './latest' + str(fold) + '.pth'

            # The last 20 % of the fold's training indices form the validation
            # split used for checkpoint selection / early stopping.
            train_size = int(train_ratio * len(train_idx))
            train_indices, valid_indices = train_idx[:train_size], train_idx[train_size:]
            datasets_train = DataLoader(dataset, batch_size=8, shuffle=False,
                                        sampler=SubsetRandomSampler(train_indices), drop_last=True)
            # NOTE(review): drop_last=True on the validation/test loaders discards
            # up to 7 held-out subjects per fold. Kept because the model may rely on
            # a fixed batch size of 8 (cf. `model_all(..., 8, win)`); switch to
            # drop_last=False if it does not.
            datasets_valid = DataLoader(dataset, batch_size=8, shuffle=False,
                                        sampler=SubsetRandomSampler(valid_indices), drop_last=True)
            datasets_test = DataLoader(dataset, batch_size=8, shuffle=False,
                                       sampler=SubsetRandomSampler(test_idx), drop_last=True)

            epoch = 300
            patiences = 80
            losses = []
            acces = []
            eval_losses = []
            eval_acces = []
            # Below any reachable accuracy, so epoch 0 always writes this fold's
            # checkpoint and the test phase never loads a stale file.
            min_acc = -1
            # Early-stopping counter, reset for every fold.
            patience = 0

            model = model_all(fmris_all.shape[3], 8, win)
            model.to(DEVICE)
            optimizer = torch.optim.Adam(model.parameters(), lr=lr)
            for e in range(epoch):
                train_loss = 0
                train_acc = 0
                model.train()
                for ot_net, cheb, label in datasets_train:
                    ot_net = ot_net.to(DEVICE).float()
                    cheb = cheb.to(DEVICE).float()
                    label = label.to(DEVICE).long()
                    # Forward pass.
                    out, loss_CC, loss_REC = model(ot_net, cheb)
                    loss = F.nll_loss(out, label) + lam * loss_CC + alpha * loss_REC

                    optimizer.zero_grad()
                    loss.backward()
                    # Clip after backward() so the freshly computed gradients are
                    # clipped; before backward() there is nothing to clip.
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                    optimizer.step()

                    train_loss += float(loss)
                    _, pred = out.max(1)
                    train_acc += (pred == label).sum().item() / ot_net.shape[0]

                losses.append(train_loss / len(datasets_train))
                acces.append(train_acc / len(datasets_train))

                (eval_loss, eval_acc, eval_acc_epoch, precision, recall, f1, my_auc,
                 pre_all, labels_all, pro_all) = stest(model, datasets_valid)
                if eval_acc_epoch > min_acc:
                    # New best validation accuracy: checkpoint, reset early stopping.
                    min_acc = eval_acc_epoch
                    torch.save(model.state_dict(), ckpt_path)
                    print("Model saved at epoch{}".format(e))
                    patience = 0
                else:
                    patience += 1
                if patience > patiences:
                    break
                eval_losses.append(eval_loss / len(datasets_valid))
                eval_acces.append(eval_acc / len(datasets_valid))
                print(
                    'i:{},epoch: {}, Train Loss: {:.6f}, Train Acc: {:.6f}, Eval Loss: {:.6f}, '
                    'Eval Acc: {:.6f},precision : {:.6f},recall : {:.6f},f1 : {:.6f},my_auc : {:.6f}  '
                    .format(fold, e, train_loss / len(datasets_train), train_acc / len(datasets_train),
                            eval_loss / len(datasets_valid), eval_acc_epoch, precision, recall, f1, my_auc))

            # Evaluate this fold's best checkpoint on its held-out test indices.
            model_test = model_all(fmris_all.shape[3], 8, win)
            model_test = model_test.to(DEVICE)
            model_test.load_state_dict(torch.load(ckpt_path, map_location=DEVICE))
            (eval_loss, eval_acc, eval_acc_epoch, precision, recall, f1, my_auc,
             pre_all, labels_all, pro_all) = stest(model_test, datasets_test)

            test_acc.append(eval_acc_epoch)
            test_pre.append(precision)
            test_recall.append(recall)
            test_f1.append(f1)
            test_auc.append(my_auc)
            label_ten.extend(labels_all)
            pro_ten.extend(pro_all)

        n_folds = len(test_acc)
        print('****************************************************************', file=log)
        print('lam', lam, 'alpha', alpha, file=log)
        print("win", win, "lr", lr, "test_acc", test_acc, file=log)
        print("win", win, "lr", lr, "test_pre", test_pre, file=log)
        print("win", win, "lr", lr, "test_recall", test_recall, file=log)
        print("win", win, "lr", lr, "test_f1", test_f1, file=log)
        print("win", win, "lr", lr, "test_auc", test_auc, file=log)

        # Means over the folds actually run (not a hard-coded 10).
        avg_acc = sum(test_acc) / n_folds
        avg_pre = sum(test_pre) / n_folds
        avg_recall = sum(test_recall) / n_folds
        avg_f1 = sum(test_f1) / n_folds
        avg_auc = sum(test_auc) / n_folds

        print("***************************************************************", file=log)
        print("win", win, "lr", lr, 'acc', avg_acc, file=log)
        # Mean *precision* (stest does not return specificity), labelled 'pre'.
        print("win", win, "lr", lr, 'pre', avg_pre, file=log)
        print("win", win, "lr", lr, 'recall', avg_recall, file=log)
        print("win", win, "lr", lr, 'f1', avg_f1, file=log)
        print("win", win, "lr", lr, 'auc', avg_auc, file=log)
        print("win", win, "lr", lr, 'label_ten', label_ten, file=log)
        print("win", win, "lr", lr, 'pro_ten', pro_ten, file=log)

        # Population standard deviation across folds (np.std == sqrt(np.var)).
        acc_std = np.std(test_acc)
        pre_std = np.std(test_pre)
        recall_std = np.std(test_recall)
        f1_std = np.std(test_f1)
        auc_std = np.std(test_auc)

        print("***************************************************************", file=log)
        print("acc_std", acc_std, file=log)
        print("pre_std", pre_std, file=log)
        print("recall_std", recall_std, file=log)
        print("f1_std", f1_std, file=log)
        print("auc_std", auc_std, file=log)

        print("dropout", 0.5, file=log)
        print('wd', 0, file=log)